60e58a
@@ -16,27 +16,21 @@
 
 package org.springframework.web.socket.server.standard;
 
-import java.lang.reflect.Method;
-import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.LinkedHashSet;
 import java.util.List;
-import java.util.Map;
+import java.util.Set;
+import javax.servlet.ServletContext;
 import javax.websocket.DeploymentException;
 import javax.websocket.server.ServerContainer;
 import javax.websocket.server.ServerEndpoint;
 import javax.websocket.server.ServerEndpointConfig;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.InitializingBean;
 import org.springframework.beans.factory.config.BeanPostProcessor;
 import org.springframework.context.ApplicationContext;
-import org.springframework.context.ApplicationContextAware;
 import org.springframework.util.Assert;
-import org.springframework.util.ClassUtils;
-import org.springframework.util.ReflectionUtils;
+import org.springframework.web.context.support.WebApplicationObjectSupport;
 
 /**
  * Detects beans of type {@link javax.websocket.server.ServerEndpointConfig} and registers
@@ -50,24 +44,36 @@
import org.springframework.util.ReflectionUtils;
  * done with the help of the {@code <absolute-ordering>} element in web.xml.
  *
  * @author Rossen Stoyanchev
+ * @author Juergen Hoeller
  * @since 4.0
  * @see ServerEndpointRegistration
  * @see SpringConfigurator
  * @see ServletServerContainerFactoryBean
  */
-public class ServerEndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware {
-
-	private static final Log logger = LogFactory.getLog(ServerEndpointExporter.class);
+public class ServerEndpointExporter extends WebApplicationObjectSupport implements BeanPostProcessor, InitializingBean {
 
+	private ServerContainer serverContainer;
 
-	private final List<Class<?>> annotatedEndpointClasses = new ArrayList<Class<?>>();
+	private List<Class<?>> annotatedEndpointClasses;
 
-	private final List<Class<?>> annotatedEndpointBeanTypes = new ArrayList<Class<?>>();
+	private Set<Class<?>> annotatedEndpointBeanTypes;
 
-	private ApplicationContext applicationContext;
 
-	private ServerContainer serverContainer;
+	/**
+	 * Set the JSR-356 {@link ServerContainer} to use for endpoint registration.
+	 * If not set, the container is going to be retrieved via the {@code ServletContext}.
+	 * @since 4.1
+	 */
+	public void setServerContainer(ServerContainer serverContainer) {
+		this.serverContainer = serverContainer;
+	}
 
+	/**
+	 * Return the JSR-356 {@link ServerContainer} to use for endpoint registration.
+	 */
+	protected ServerContainer getServerContainer() {
+		return this.serverContainer;
+	}
 
 	/**
 	 * Explicitly list annotated endpoint types that should be registered on startup. This
@@ -76,17 +82,19 @@
public class ServerEndpointExporter implements InitializingBean, BeanPostProcess
 	 * @param annotatedEndpointClasses {@link ServerEndpoint}-annotated types
  	 */
 	public void setAnnotatedEndpointClasses(Class<?>... annotatedEndpointClasses) {
-		this.annotatedEndpointClasses.clear();
-		this.annotatedEndpointClasses.addAll(Arrays.asList(annotatedEndpointClasses));
+		this.annotatedEndpointClasses = Arrays.asList(annotatedEndpointClasses);
 	}
 
 	@Override
-	public void setApplicationContext(ApplicationContext applicationContext) {
-		this.applicationContext = applicationContext;
-		this.serverContainer = getServerContainer();
-		Map<String, Object> beans = applicationContext.getBeansWithAnnotation(ServerEndpoint.class);
-		for (String beanName : beans.keySet()) {
-			Class<?> beanType = applicationContext.getType(beanName);
+	protected void initApplicationContext(ApplicationContext context) {
+		// Initializes ServletContext given a WebApplicationContext
+		super.initApplicationContext(context);
+
+		// Retrieve beans which are annotated with @ServerEndpoint
+		this.annotatedEndpointBeanTypes = new LinkedHashSet<Class<?>>();
+		String[] beanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class);
+		for (String beanName : beanNames) {
+			Class<?> beanType = context.getType(beanName);
 			if (logger.isInfoEnabled()) {
 				logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type");
 			}
@@ -94,66 +102,72 @@
public class ServerEndpointExporter implements InitializingBean, BeanPostProcess
 		}
 	}
 
-	protected ServerContainer getServerContainer() {
-		Class<?> servletContextClass;
-		try {
-			servletContextClass = ClassUtils.forName("javax.servlet.ServletContext", getClass().getClassLoader());
+	@Override
+	protected void initServletContext(ServletContext servletContext) {
+		if (this.serverContainer == null) {
+			this.serverContainer =
+					(ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer");
+		}
+	}
+
+
+	@Override
+	public void afterPropertiesSet() {
+		Assert.state(getServerContainer() != null, "javax.websocket.server.ServerContainer not available");
+		registerEndpoints();
+	}
+
+	/**
+	 * Actually register the endpoints. Called by {@link #afterPropertiesSet()}.
+	 * @since 4.1
+	 */
+	protected void registerEndpoints() {
+		Set<Class<?>> endpointClasses = new LinkedHashSet<Class<?>>();
+		if (this.annotatedEndpointClasses != null) {
+			endpointClasses.addAll(this.annotatedEndpointClasses);
+		}
+		if (this.annotatedEndpointBeanTypes != null) {
+			endpointClasses.addAll(this.annotatedEndpointBeanTypes);
 		}
-		catch (Throwable ex) {
-			return null;
+		for (Class<?> endpointClass : endpointClasses) {
+			registerEndpoint(endpointClass);
 		}
+	}
 
+	private void registerEndpoint(Class<?> endpointClass) {
 		try {
-			Method getter = ReflectionUtils.findMethod(this.applicationContext.getClass(), "getServletContext");
-			Object servletContext = getter.invoke(this.applicationContext);
-			Method attrMethod = ReflectionUtils.findMethod(servletContextClass, "getAttribute", String.class);
-			return (ServerContainer) attrMethod.invoke(servletContext, "javax.websocket.server.ServerContainer");
+			if (logger.isInfoEnabled()) {
+				logger.info("Registering @ServerEndpoint type: " + endpointClass);
+			}
+			getServerContainer().addEndpoint(endpointClass);
 		}
-		catch (Exception ex) {
-			throw new IllegalStateException(
-					"Failed to get javax.websocket.server.ServerContainer via ServletContext attribute", ex);
+		catch (DeploymentException ex) {
+			throw new IllegalStateException("Failed to register @ServerEndpoint type " + endpointClass, ex);
 		}
 	}
 
-	@Override
-	public void afterPropertiesSet() throws Exception {
-		Assert.state(this.serverContainer != null, "javax.websocket.server.ServerContainer not available");
-
-		List<Class<?>> allClasses = new ArrayList<Class<?>>(this.annotatedEndpointClasses);
-		allClasses.addAll(this.annotatedEndpointBeanTypes);
 
-		for (Class<?> clazz : allClasses) {
-			try {
-				logger.info("Registering @ServerEndpoint type " + clazz);
-				this.serverContainer.addEndpoint(clazz);
-			}
-			catch (DeploymentException e) {
-				throw new IllegalStateException("Failed to register @ServerEndpoint type " + clazz, e);
-			}
-		}
+	@Override
+	public Object postProcessBeforeInitialization(Object bean, String beanName) {
+		return bean;
 	}
 
 	@Override
-	public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
+	public Object postProcessAfterInitialization(Object bean, String beanName) {
 		if (bean instanceof ServerEndpointConfig) {
-			ServerEndpointConfig sec = (ServerEndpointConfig) bean;
+			ServerEndpointConfig endpointConfig = (ServerEndpointConfig) bean;
 			try {
 				if (logger.isInfoEnabled()) {
 					logger.info("Registering bean '" + beanName +
-							"' as javax.websocket.Endpoint under path " + sec.getPath());
+							"' as javax.websocket.Endpoint under path " + endpointConfig.getPath());
 				}
-				getServerContainer().addEndpoint(sec);
+				getServerContainer().addEndpoint(endpointConfig);
 			}
-			catch (DeploymentException e) {
-				throw new IllegalStateException("Failed to deploy Endpoint bean " + bean, e);
+			catch (DeploymentException ex) {
+				throw new IllegalStateException("Failed to deploy Endpoint bean with name '" + beanName + "'", ex);
 			}
 		}
 		return bean;
 	}
 
-	@Override
-	public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
-		return bean;
-	}
-
 }
